/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.grouping;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.valuesource.BytesRefFieldSource;
import org.apache.lucene.search.CachingCollector;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.mutable.MutableValue;
import org.apache.lucene.util.mutable.MutableValueStr;

// TODO
//   - should test relevance sort too
//   - test null
//   - test ties
//   - test compound sort

public class TestGrouping extends LuceneTestCase {

  /**
   * Indexes seven docs (six carrying an "author" group value, one without it), groups a term
   * query for "random" by author using relevance sort, and checks group order, per-group doc order
   * and score ordering.
   */
  public void testBasic() throws Exception {

    String groupField = "author";

    // Stored-only type for the "id" field (the id is never read by the assertions below)
    FieldType customType = new FieldType();
    customType.setStored(true);

    Directory dir = newDirectory();
    RandomIndexWriter w =
        new RandomIndexWriter(
            random(),
            dir,
            newIndexWriterConfig(new MockAnalyzer(random())).setMergePolicy(newLogMergePolicy()));
    // 0
    Document doc = new Document();
    addGroupField(doc, groupField, "author1");
    doc.add(new TextField("content", "random text", Field.Store.YES));
    doc.add(new Field("id", "1", customType));
    w.addDocument(doc);

    // 1
    doc = new Document();
    addGroupField(doc, groupField, "author1");
    doc.add(new TextField("content", "some more random text", Field.Store.YES));
    doc.add(new Field("id", "2", customType));
    w.addDocument(doc);

    // 2
    doc = new Document();
    addGroupField(doc, groupField, "author1");
    doc.add(new TextField("content", "some more random textual data", Field.Store.YES));
    doc.add(new Field("id", "3", customType));
    w.addDocument(doc);

    // 3
    doc = new Document();
    addGroupField(doc, groupField, "author2");
    doc.add(new TextField("content", "some random text", Field.Store.YES));
    doc.add(new Field("id", "4", customType));
    w.addDocument(doc);

    // 4
    doc = new Document();
    addGroupField(doc, groupField, "author3");
    doc.add(new TextField("content", "some more random text", Field.Store.YES));
    doc.add(new Field("id", "5", customType));
    w.addDocument(doc);

    // 5
    doc = new Document();
    addGroupField(doc, groupField, "author3");
    doc.add(new TextField("content", "random", Field.Store.YES));
    doc.add(new Field("id", "6", customType));
    w.addDocument(doc);

    // 6 -- no author field; ids are sequential, so this one is "7" (was a duplicate "6")
    doc = new Document();
    doc.add(new TextField("content", "random word stuck in alot of other text", Field.Store.YES));
    doc.add(new Field("id", "7", customType));
    w.addDocument(doc);

    IndexSearcher indexSearcher = newSearcher(w.getReader());
    // This test relies on the fact that longer fields produce lower scores
    indexSearcher.setSimilarity(new BM25Similarity());
    w.close();

    final Sort groupSort = Sort.RELEVANCE;

    final FirstPassGroupingCollector<?> c1 =
        createRandomFirstPassCollector(groupField, groupSort, 10);
    indexSearcher.search(new TermQuery(new Term("content", "random")), c1);

    final TopGroupsCollector<?> c2 =
        createSecondPassCollector(c1, groupSort, Sort.RELEVANCE, 0, 5, true);
    indexSearcher.search(new TermQuery(new Term("content", "random")), c2);

    final TopGroups<?> groups = c2.getTopGroups(0);
    assertFalse(Float.isNaN(groups.maxScore));

    assertEquals(7, groups.totalHitCount);
    assertEquals(7, groups.totalGroupedHitCount);
    assertEquals(4, groups.groups.length);

    // relevance order: 5, 0, 3, 4, 1, 2, 6

    // the later a document is added the higher this docId
    // value
    GroupDocs<?> group = groups.groups[0];
    compareGroupValue("author3", group);
    assertEquals(2, group.scoreDocs().length);
    assertEquals(5, group.scoreDocs()[0].doc);
    assertEquals(4, group.scoreDocs()[1].doc);
    assertTrue(group.scoreDocs()[0].score > group.scoreDocs()[1].score);

    group = groups.groups[1];
    compareGroupValue("author1", group);
    assertEquals(3, group.scoreDocs().length);
    assertEquals(0, group.scoreDocs()[0].doc);
    assertEquals(1, group.scoreDocs()[1].doc);
    assertEquals(2, group.scoreDocs()[2].doc);
    assertTrue(group.scoreDocs()[0].score >= group.scoreDocs()[1].score);
    assertTrue(group.scoreDocs()[1].score >= group.scoreDocs()[2].score);

    group = groups.groups[2];
    compareGroupValue("author2", group);
    assertEquals(1, group.scoreDocs().length);
    assertEquals(3, group.scoreDocs()[0].doc);

    // The doc without an author forms its own (missing-value) group
    group = groups.groups[3];
    compareGroupValue(null, group);
    assertEquals(1, group.scoreDocs().length);
    assertEquals(6, group.scoreDocs()[0].doc);

    indexSearcher.getIndexReader().close();
    dir.close();
  }

  /**
   * Verifies that {@link FirstPassGroupingCollector} reports a separate group for docs lacking the
   * group field by default, and drops that group when its trailing boolean constructor argument is
   * {@code true}.
   */
  public void testIgnoreDocsWithoutGroupField() throws IOException {
    Directory dir = newDirectory();
    RandomIndexWriter w =
        new RandomIndexWriter(random(), dir, newIndexWriterConfig(new MockAnalyzer(random())));

    String groupField = "group";
    // Add documents with group field
    Document doc = new Document();
    addGroupField(doc, groupField, "group1");
    doc.add(new TextField("content", "test", Field.Store.YES));
    w.addDocument(doc);

    doc = new Document();
    addGroupField(doc, groupField, "group2");
    doc.add(new TextField("content", "test", Field.Store.YES));
    w.addDocument(doc);

    // Add document without group field
    doc = new Document();
    doc.add(new TextField("content", "test", Field.Store.YES));
    w.addDocument(doc);

    DirectoryReader reader = w.getReader();
    w.close();

    IndexSearcher searcher = newSearcher(reader);

    // Test default behavior (include null group)
    FirstPassGroupingCollector<BytesRef> collector1 =
        new FirstPassGroupingCollector<>(new TermGroupSelector(groupField), Sort.RELEVANCE, 10);
    searcher.search(new MatchAllDocsQuery(), collector1);
    Collection<SearchGroup<BytesRef>> groups1 = collector1.getTopGroups(0);

    assertEquals(3, groups1.size()); // Should include null group

    // Test ignoring docs without group field
    FirstPassGroupingCollector<BytesRef> collector2 =
        new FirstPassGroupingCollector<>(
            new TermGroupSelector(groupField), Sort.RELEVANCE, 10, true);
    searcher.search(new MatchAllDocsQuery(), collector2);
    Collection<SearchGroup<BytesRef>> groups2 = collector2.getTopGroups(0);

    assertEquals(2, groups2.size()); // Should exclude null group
    // The size alone does not prove *which* group was dropped: pin the surviving values.
    Set<BytesRef> groupValues2 = new HashSet<>();
    for (SearchGroup<BytesRef> group : groups2) {
      assertNotNull("group of docs without the group field must be excluded", group.groupValue);
      groupValues2.add(group.groupValue);
    }
    assertEquals(Set.of(new BytesRef("group1"), new BytesRef("group2")), groupValues2);

    reader.close();
    dir.close();
  }

  /**
   * When no indexed doc has the group field and docs without it are ignored, the first pass finds
   * no groups at all and {@code getTopGroups} returns {@code null}.
   */
  public void testAllDocsWithoutGroupField() throws IOException {
    final Directory directory = newDirectory();
    final RandomIndexWriter writer =
        new RandomIndexWriter(
            random(), directory, newIndexWriterConfig(new MockAnalyzer(random())));

    // Five docs, none of which carries the "group" field
    for (int remaining = 5; remaining > 0; remaining--) {
      final Document noGroupDoc = new Document();
      noGroupDoc.add(new TextField("content", "test", Field.Store.YES));
      writer.addDocument(noGroupDoc);
    }

    final DirectoryReader directoryReader = writer.getReader();
    writer.close();

    final IndexSearcher indexSearcher = newSearcher(directoryReader);

    // Ignore docs lacking the group field; since every doc lacks it, nothing is grouped
    final FirstPassGroupingCollector<BytesRef> firstPass =
        new FirstPassGroupingCollector<>(new TermGroupSelector("group"), Sort.RELEVANCE, 10, true);
    indexSearcher.search(new MatchAllDocsQuery(), firstPass);
    final Collection<SearchGroup<BytesRef>> topGroups = firstPass.getTopGroups(0);

    assertNull(topGroups); // Should return null when no groups found

    directoryReader.close();
    directory.close();
  }

  /** Adds {@code value} to {@code doc} as a sorted doc-values field named {@code groupField}. */
  private void addGroupField(Document doc, String groupField, String value) {
    final BytesRef groupBytes = new BytesRef(value);
    final SortedDocValuesField groupDocValues = new SortedDocValuesField(groupField, groupBytes);
    doc.add(groupDocValues);
  }

  /**
   * Creates a first-pass collector over {@code groupField}. One random boolean decides whether it
   * groups through a {@link ValueSourceGroupSelector} (true) or a {@link TermGroupSelector} (false).
   */
  private FirstPassGroupingCollector<?> createRandomFirstPassCollector(
      String groupField, Sort groupSort, int topDocs) throws IOException {
    final boolean useValueSource = random().nextBoolean();
    if (!useValueSource) {
      return new FirstPassGroupingCollector<>(
          new TermGroupSelector(groupField), groupSort, topDocs);
    }
    final ValueSource fieldSource = new BytesRefFieldSource(groupField);
    final ValueSourceGroupSelector valueSelector =
        new ValueSourceGroupSelector(fieldSource, new HashMap<>());
    return new FirstPassGroupingCollector<>(valueSelector, groupSort, topDocs);
  }

  /**
   * Creates a first-pass collector that uses the <em>other</em> selector implementation from the
   * one used by {@code firstPassGroupingCollector}. If that collector is term-based, a
   * {@link ValueSourceGroupSelector} is used; otherwise a {@link TermGroupSelector} is used.
   */
  private FirstPassGroupingCollector<?> createFirstPassCollector(
      String groupField,
      Sort groupSort,
      int topDocs,
      FirstPassGroupingCollector<?> firstPassGroupingCollector)
      throws IOException {
    final Class<?> previousSelectorClass =
        firstPassGroupingCollector.getGroupSelector().getClass();
    final boolean previousWasTermBased =
        TermGroupSelector.class.isAssignableFrom(previousSelectorClass);
    if (!previousWasTermBased) {
      return new FirstPassGroupingCollector<>(
          new TermGroupSelector(groupField), groupSort, topDocs);
    }
    final ValueSource fieldSource = new BytesRefFieldSource(groupField);
    return new FirstPassGroupingCollector<>(
        new ValueSourceGroupSelector(fieldSource, new HashMap<>()), groupSort, topDocs);
  }

  /**
   * Builds a second-pass {@link TopGroupsCollector}. It reuses the selector of
   * {@code firstPassGroupingCollector} and that collector's top groups starting at
   * {@code groupOffset}.
   */
  @SuppressWarnings({"unchecked", "rawtypes"})
  private <T> TopGroupsCollector<T> createSecondPassCollector(
      FirstPassGroupingCollector firstPassGroupingCollector,
      Sort groupSort,
      Sort sortWithinGroup,
      int groupOffset,
      int maxDocsPerGroup,
      boolean getMaxScores)
      throws IOException {
    // Raw first-pass collector: the group value type is recovered via unchecked conversion
    final Collection<SearchGroup<T>> topSearchGroups =
        firstPassGroupingCollector.getTopGroups(groupOffset);
    final GroupSelector<T> sharedSelector = firstPassGroupingCollector.getGroupSelector();
    return new TopGroupsCollector<>(
        sharedSelector,
        topSearchGroups,
        groupSort,
        sortWithinGroup,
        maxDocsPerGroup,
        getMaxScores);
  }

  /**
   * Creates a second-pass collector from already merged search groups whose values are
   * {@link BytesRef}s.
   *
   * <p>If the first pass grouped with a {@link TermGroupSelector}, that selector and
   * {@code searchGroups} are used unchanged. Otherwise each group value is converted from
   * {@link BytesRef} to {@link MutableValueStr}, and a new {@link ValueSourceGroupSelector} over
   * {@code groupField} is created. A {@code null} group value becomes a value with
   * {@code exists == false}.
   */
  private TopGroupsCollector<?> createSecondPassCollector(
      FirstPassGroupingCollector<?> firstPassGroupingCollector,
      String groupField,
      Collection<SearchGroup<BytesRef>> searchGroups,
      Sort groupSort,
      Sort sortWithinGroup,
      int maxDocsPerGroup,
      boolean getMaxScores)
      throws IOException {
    final GroupSelector<?> firstPassSelector = firstPassGroupingCollector.getGroupSelector();
    if (firstPassSelector instanceof TermGroupSelector) {
      // Term grouping already works on BytesRef values: no conversion needed
      GroupSelector<BytesRef> selector = (TermGroupSelector) firstPassSelector;
      return new TopGroupsCollector<>(
          selector, searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
    } else {
      ValueSource vs = new BytesRefFieldSource(groupField);
      List<SearchGroup<MutableValue>> mvalSearchGroups = new ArrayList<>(searchGroups.size());
      for (SearchGroup<BytesRef> mergedTopGroup : searchGroups) {
        SearchGroup<MutableValue> sg = new SearchGroup<>();
        MutableValueStr groupValue = new MutableValueStr();
        if (mergedTopGroup.groupValue != null) {
          groupValue.value.copyBytes(mergedTopGroup.groupValue);
        } else {
          // The missing-value group is represented by a non-existent MutableValue
          groupValue.exists = false;
        }
        sg.groupValue = groupValue;
        sg.sortValues = mergedTopGroup.sortValues;
        mvalSearchGroups.add(sg);
      }
      ValueSourceGroupSelector selector = new ValueSourceGroupSelector(vs, new HashMap<>());
      return new TopGroupsCollector<>(
          selector, mvalSearchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
    }
  }

  /**
   * Creates an {@link AllGroupsCollector} sharing the selector of
   * {@code firstPassGroupingCollector}. {@code groupField} is accepted but not used by this
   * method.
   */
  private AllGroupsCollector<?> createAllGroupsCollector(
      FirstPassGroupingCollector<?> firstPassGroupingCollector, String groupField) {
    final GroupSelector<?> sharedSelector = firstPassGroupingCollector.getGroupSelector();
    return new AllGroupsCollector<>(sharedSelector);
  }

  /**
   * Asserts that the group value of {@code group} matches {@code expected}. The value may be a
   * {@link BytesRef} (term grouping) or a {@link MutableValueStr} (value-source grouping).
   *
   * <p>A {@code null} {@code expected} means "docs without the group field". It accepts a
   * {@code null} value, a {@link MutableValueStr} whose {@code exists()} is false, or an empty
   * {@link BytesRef}.
   */
  private void compareGroupValue(String expected, GroupDocs<?> group) {
    final Object actual = group.groupValue();
    if (expected == null) {
      if (actual == null) {
        return;
      }
      if (actual instanceof MutableValueStr) {
        // A present-but-wrong value must not slip through as "missing"
        assertFalse(
            "expected a missing group value, got " + actual, ((MutableValueStr) actual).exists());
        return;
      }
      if (actual instanceof BytesRef) {
        // Doc-values based grouping cannot distinguish an empty value from a missing one
        assertEquals(0, ((BytesRef) actual).length);
        return;
      }
      fail("unexpected group value type: " + actual.getClass());
      return;
    }

    if (actual instanceof BytesRef) {
      assertEquals(new BytesRef(expected), actual);
    } else if (actual instanceof MutableValueStr) {
      MutableValueStr v = new MutableValueStr();
      v.value.copyChars(expected);
      assertEquals(v, actual);
    } else {
      fail("expected group value " + expected + ", got " + actual);
    }
  }

  /**
   * Returns the first-pass top groups of {@code c} from {@code groupOffset} on, with group values
   * normalized to {@link BytesRef}.
   *
   * <p>Term-based collectors are returned as-is. For value-source collectors, each
   * {@link MutableValueStr} is mapped to its bytes, or to {@code null} when the value does not
   * exist. Returns {@code null} when the collector found no groups. Fails for any other selector
   * type.
   */
  private Collection<SearchGroup<BytesRef>> getSearchGroups(
      FirstPassGroupingCollector<?> c, int groupOffset) throws IOException {
    final Class<?> selectorClass = c.getGroupSelector().getClass();

    if (TermGroupSelector.class.isAssignableFrom(selectorClass)) {
      @SuppressWarnings("unchecked")
      final FirstPassGroupingCollector<BytesRef> termCollector =
          (FirstPassGroupingCollector<BytesRef>) c;
      return termCollector.getTopGroups(groupOffset);
    }

    if (!ValueSourceGroupSelector.class.isAssignableFrom(selectorClass)) {
      fail();
      return null;
    }

    @SuppressWarnings("unchecked")
    final FirstPassGroupingCollector<MutableValue> valueSourceCollector =
        (FirstPassGroupingCollector<MutableValue>) c;
    final Collection<SearchGroup<MutableValue>> sourceGroups =
        valueSourceCollector.getTopGroups(groupOffset);
    if (sourceGroups == null) {
      return null;
    }

    final List<SearchGroup<BytesRef>> converted = new ArrayList<>(sourceGroups.size());
    for (SearchGroup<MutableValue> sourceGroup : sourceGroups) {
      final SearchGroup<BytesRef> target = new SearchGroup<>();
      final MutableValue sourceValue = sourceGroup.groupValue;
      if (sourceValue.exists()) {
        target.groupValue = ((MutableValueStr) sourceValue).value.get();
      } else {
        target.groupValue = null;
      }
      target.sortValues = sourceGroup.sortValues;
      converted.add(target);
    }
    return converted;
  }

  /**
   * Returns the top groups collected by {@code c}, with group values normalized to
   * {@link BytesRef}.
   *
   * <p>For a {@link TermGroupSelector}, the collector's result is returned directly. For a
   * {@link ValueSourceGroupSelector}, each {@link MutableValueStr} group value is converted to its
   * bytes, or to {@code null} when the value does not exist. The per-group {@code score} and the
   * overall max score are not carried over and are set to {@link Float#NaN}. Fails for any other
   * selector type.
   */
  @SuppressWarnings({"unchecked", "rawtypes"})
  private TopGroups<BytesRef> getTopGroups(TopGroupsCollector c, int withinGroupOffset) {
    final GroupSelector<?> selector = c.getGroupSelector();
    if (selector instanceof TermGroupSelector) {
      TopGroupsCollector<BytesRef> collector = (TopGroupsCollector<BytesRef>) c;
      return collector.getTopGroups(withinGroupOffset);
    } else if (selector instanceof ValueSourceGroupSelector) {
      TopGroupsCollector<MutableValue> collector = (TopGroupsCollector<MutableValue>) c;
      TopGroups<MutableValue> mvalTopGroups = collector.getTopGroups(withinGroupOffset);
      List<GroupDocs<BytesRef>> groups = new ArrayList<>(mvalTopGroups.groups.length);
      for (GroupDocs<MutableValue> mvalGd : mvalTopGroups.groups) {
        BytesRef groupValue =
            mvalGd.groupValue().exists()
                ? ((MutableValueStr) mvalGd.groupValue()).value.get()
                : null;
        groups.add(
            new GroupDocs<>(
                Float.NaN,
                mvalGd.maxScore(),
                mvalGd.totalHits(),
                mvalGd.scoreDocs(),
                groupValue,
                mvalGd.groupSortValues()));
      }
      return new TopGroups<>(
          mvalTopGroups.groupSort,
          mvalTopGroups.withinGroupSort,
          mvalTopGroups.totalHitCount,
          mvalTopGroups.totalGroupedHitCount,
          groups.toArray(new GroupDocs[groups.size()]),
          Float.NaN);
    }
    fail();
    return null;
  }

  /** In-memory model of one indexed document, used to build the index and the expected results. */
  private static class GroupDoc {
    final int id;
    // Group value; null when the document has no group field
    final BytesRef group;
    final BytesRef sort1;
    final BytesRef sort2;
    // content must be "realN ..."
    final String content;
    // Not set by the constructor; presumably assigned later by the test — verify at call sites
    float score;
    float score2;

    public GroupDoc(
        int docId, BytesRef groupValue, BytesRef sortValue1, BytesRef sortValue2, String text) {
      id = docId;
      group = groupValue;
      sort1 = sortValue1;
      sort2 = sortValue2;
      content = text;
    }
  }

  /**
   * Returns a random sort. It is one of: score only; one of "sort1"/"sort2"; both; or none of
   * them. The string fields get random reverse flags, and the sort always ends with an "id" tie
   * breaker.
   */
  private Sort getRandomSort() {
    final List<SortField> fields = new ArrayList<>();
    if (random().nextInt(7) == 2) {
      fields.add(SortField.FIELD_SCORE);
    } else if (random().nextBoolean()) {
      // Exactly one string field, chosen at random
      final String fieldName = random().nextBoolean() ? "sort1" : "sort2";
      fields.add(new SortField(fieldName, SortField.Type.STRING, random().nextBoolean()));
    } else if (random().nextBoolean()) {
      fields.add(new SortField("sort1", SortField.Type.STRING, random().nextBoolean()));
      fields.add(new SortField("sort2", SortField.Type.STRING, random().nextBoolean()));
    }
    // Break ties:
    fields.add(new SortField("id", SortField.Type.INT));
    return new Sort(fields.toArray(new SortField[0]));
  }

  /**
   * Returns a comparator over {@link GroupDoc}s that follows {@code sort}. Score sorts put higher
   * scores first, "sort1"/"sort2" compare bytes, and "id" compares numerically; each field honors
   * its reverse flag.
   *
   * <p>The test's sorts always end in a unique "id" field, so a full tie signals a bug and fails.
   */
  private Comparator<GroupDoc> getComparator(Sort sort) {
    final SortField[] sortFields = sort.getSort();
    return (d1, d2) -> {
      for (SortField sf : sortFields) {
        final int cmp;
        if (sf.getType() == SortField.Type.SCORE) {
          // Relevance: higher score sorts first
          if (d1.score > d2.score) {
            cmp = -1;
          } else if (d1.score < d2.score) {
            cmp = 1;
          } else {
            cmp = 0;
          }
        } else if (sf.getField().equals("sort1")) {
          cmp = d1.sort1.compareTo(d2.sort1);
        } else if (sf.getField().equals("sort2")) {
          cmp = d1.sort2.compareTo(d2.sort2);
        } else {
          assertEquals("id", sf.getField());
          cmp = Integer.compare(d1.id, d2.id);
        }
        if (cmp != 0) {
          return sf.getReverse() ? -cmp : cmp;
        }
      }
      // Our sort always fully tie breaks:
      fail();
      return 0;
    };
  }

  /**
   * Returns the sort values of {@code d} for each field of {@code sort}, in order. The values are
   * the score (as Float), the "sort1"/"sort2" bytes, or the id (as Integer). This mirrors what a
   * {@link FieldDoc} would carry.
   */
  @SuppressWarnings({"unchecked", "rawtypes"})
  private Comparable<?>[] fillFields(GroupDoc d, Sort sort) {
    final SortField[] sortFields = sort.getSort();
    final Comparable<?>[] values = new Comparable[sortFields.length];
    for (int slot = 0; slot < values.length; slot++) {
      final SortField field = sortFields[slot];
      if (field.getType() == SortField.Type.SCORE) {
        values[slot] = d.score;
        continue;
      }
      final String name = field.getField();
      if (name.equals("sort1")) {
        values[slot] = d.sort1;
      } else if (name.equals("sort2")) {
        values[slot] = d.sort2;
      } else {
        assertEquals("id", name);
        values[slot] = d.id;
      }
    }
    return values;
  }

  /** Renders a group value for debug output, using the literal "null" for a missing group. */
  private String groupToString(BytesRef b) {
    return b == null ? "null" : b.utf8ToString();
  }

  /**
   * Brute-force reference grouping used to compute the expected result for the real collectors.
   *
   * <p>Matches docs whose content starts with {@code searchTerm}. It orders groups by their top doc
   * under {@code groupSort} and returns groups {@code [groupOffset, groupOffset + topNGroups)}.
   * Each returned group holds docs {@code [docOffset, docOffset + docsPerGroup)} sorted by
   * {@code docSort}. When {@code doAllGroups} is set, the result also carries the total number of
   * distinct matching groups. {@code getMaxScores} is not used here.
   *
   * <p>NOTE: sorts {@code groupDocs} in place by {@code groupSort}.
   *
   * @return the expected top groups, or {@code null} if {@code groupOffset} is past the last group
   */
  private TopGroups<BytesRef> slowGrouping(
      GroupDoc[] groupDocs,
      String searchTerm,
      boolean getMaxScores,
      boolean doAllGroups,
      Sort groupSort,
      Sort docSort,
      int topNGroups,
      int docsPerGroup,
      int groupOffset,
      int docOffset) {

    // Visiting docs in group-sort order makes the first doc seen for each group its top doc,
    // which fixes both the group's rank and its sort values.
    Arrays.sort(groupDocs, getComparator(groupSort));
    final Map<BytesRef, List<GroupDoc>> groups = new HashMap<>();
    final List<BytesRef> sortedGroups = new ArrayList<>();
    final List<Comparable<?>[]> sortedGroupFields = new ArrayList<>();
    final Set<BytesRef> knownGroups = new HashSet<>();
    int totalHitCount = 0;

    for (GroupDoc d : groupDocs) {
      // TODO: would be better to filter by searchTerm before sorting!
      if (!d.content.startsWith(searchTerm)) {
        continue;
      }
      totalHitCount++;

      if (doAllGroups) {
        knownGroups.add(d.group);
      }

      List<GroupDoc> docsInGroup = groups.get(d.group);
      if (docsInGroup == null) {
        sortedGroups.add(d.group);
        sortedGroupFields.add(fillFields(d, groupSort));
        docsInGroup = new ArrayList<>();
        groups.put(d.group, docsInGroup);
      }
      docsInGroup.add(d);
    }

    if (groupOffset >= sortedGroups.size()) {
      // slice is out of bounds
      return null;
    }

    final int limit = Math.min(groupOffset + topNGroups, sortedGroups.size());

    final Comparator<GroupDoc> docSortComp = getComparator(docSort);
    @SuppressWarnings({"unchecked", "rawtypes"})
    final GroupDocs<BytesRef>[] result = new GroupDocs[limit - groupOffset];
    int totalGroupedHitCount = 0;
    for (int idx = groupOffset; idx < limit; idx++) {
      final BytesRef group = sortedGroups.get(idx);
      final List<GroupDoc> docs = groups.get(group);
      totalGroupedHitCount += docs.size();
      docs.sort(docSortComp);
      final ScoreDoc[] hits;
      if (docs.size() > docOffset) {
        final int docIDXLimit = Math.min(docOffset + docsPerGroup, docs.size());
        hits = new ScoreDoc[docIDXLimit - docOffset];
        for (int docIDX = docOffset; docIDX < docIDXLimit; docIDX++) {
          final GroupDoc d = docs.get(docIDX);
          hits[docIDX - docOffset] = new FieldDoc(d.id, Float.NaN, fillFields(d, docSort));
        }
      } else {
        hits = new ScoreDoc[0];
      }

      result[idx - groupOffset] =
          new GroupDocs<>(
              Float.NaN,
              0.0f,
              new TotalHits(docs.size(), TotalHits.Relation.EQUAL_TO),
              hits,
              group,
              sortedGroupFields.get(idx));
    }

    final TopGroups<BytesRef> topGroups =
        new TopGroups<>(
            groupSort.getSort(),
            docSort.getSort(),
            totalHitCount,
            totalGroupedHitCount,
            result,
            Float.NaN);
    if (doAllGroups) {
      return new TopGroups<>(topGroups, knownGroups.size());
    }
    return topGroups;
  }

  /**
   * Indexes {@code groupDocs} into {@code dir}, with all docs of a group added together as one
   * document block via {@code addDocuments}. Groups are written in a random order. The last doc of
   * each block gets an extra "groupend" field. The docs without a group value (null group) also
   * form a single block, without "group" fields.
   *
   * <p>NOTE: shuffles {@code groupDocs} in place, because {@code Arrays.asList} is a view backed by
   * the array. The sequence of {@code random()} calls here is part of the test's seed
   * reproducibility; do not reorder statements.
   *
   * @return a reader over the written index; the caller must close it
   */
  private DirectoryReader getDocBlockReader(Directory dir, GroupDoc[] groupDocs)
      throws IOException {
    // Coalesce by group, but in random order:
    Collections.shuffle(Arrays.asList(groupDocs), random());
    final Map<BytesRef, List<GroupDoc>> groupMap = new HashMap<>();
    // Distinct group values (may include null) in first-seen order after the shuffle
    final List<BytesRef> groupValues = new ArrayList<>();

    for (GroupDoc groupDoc : groupDocs) {
      if (!groupMap.containsKey(groupDoc.group)) {
        groupValues.add(groupDoc.group);
        groupMap.put(groupDoc.group, new ArrayList<>());
      }
      groupMap.get(groupDoc.group).add(groupDoc);
    }

    RandomIndexWriter w =
        new RandomIndexWriter(
            random(),
            dir,
            newIndexWriterConfig(new MockAnalyzer(random()))
                .setMergePolicy(newMergePolicy(random(), false)));

    // Blocks chosen at random to be re-indexed via updateDocuments after the initial pass
    final List<List<Document>> updateDocs = new ArrayList<>();

    // Indexed-only (docs, no norms) marker type for the "groupend" field
    FieldType groupEndType = new FieldType(StringField.TYPE_NOT_STORED);
    groupEndType.setIndexOptions(IndexOptions.DOCS);
    groupEndType.setOmitNorms(true);

    // System.out.println("TEST: index groups");
    for (BytesRef group : groupValues) {
      final List<Document> docs = new ArrayList<>();
      // System.out.println("TEST:   group=" + (group == null ? "null" : group.utf8ToString()));
      for (GroupDoc groupValue : groupMap.get(group)) {
        Document doc = new Document();
        docs.add(doc);
        // The stored "group" string is what the update pass below reads back to build its Term
        if (groupValue.group != null) {
          doc.add(newStringField("group", groupValue.group.utf8ToString(), Field.Store.YES));
          doc.add(new SortedDocValuesField("group", BytesRef.deepCopyOf(groupValue.group)));
        }
        doc.add(newStringField("sort1", groupValue.sort1.utf8ToString(), Field.Store.NO));
        doc.add(new SortedDocValuesField("sort1", BytesRef.deepCopyOf(groupValue.sort1)));
        doc.add(newStringField("sort2", groupValue.sort2.utf8ToString(), Field.Store.NO));
        doc.add(new SortedDocValuesField("sort2", BytesRef.deepCopyOf(groupValue.sort2)));
        doc.add(new NumericDocValuesField("id", groupValue.id));
        doc.add(newTextField("content", groupValue.content, Field.Store.NO));
        // System.out.println("TEST:     doc content=" + groupValue.content + " group=" +
        // (groupValue.group == null ? "null" : groupValue.group.utf8ToString()) + " sort1=" +
        // groupValue.sort1.utf8ToString() + " id=" + groupValue.id);
      }
      // Mark the last doc of the block, presumably so block-grouping code elsewhere in this test
      // can find block boundaries — verify against its caller:
      final Field groupEnd = newField("groupend", "x", groupEndType);
      docs.get(docs.size() - 1).add(groupEnd);
      // Add as a doc block:
      w.addDocuments(docs);
      // Roughly 1 in 7 non-null groups is scheduled for re-indexing below
      if (group != null && random().nextInt(7) == 4) {
        updateDocs.add(docs);
      }
    }

    for (List<Document> docs : updateDocs) {
      // Just replaces docs w/ same docs: the block is re-added in full under its "group" term
      w.updateDocuments(new Term("group", docs.get(0).get("group")), docs);
    }

    final DirectoryReader r = w.getReader();
    w.close();

    return r;
  }

  /**
   * Splits a searcher into one {@link ShardSearcher} per leaf of its top reader context. It also
   * records each leaf's {@code docBase} at the same index.
   */
  private static class ShardState {

    public final ShardSearcher[] subSearchers;
    public final int[] docStarts;

    public ShardState(IndexSearcher s) {
      final IndexReaderContext topContext = s.getTopReaderContext();
      final List<LeafReaderContext> leafContexts = topContext.leaves();
      final int leafCount = leafContexts.size();
      subSearchers = new ShardSearcher[leafCount];
      docStarts = new int[leafCount];
      for (int leafIdx = 0; leafIdx < leafCount; leafIdx++) {
        final LeafReaderContext leaf = leafContexts.get(leafIdx);
        subSearchers[leafIdx] = new ShardSearcher(leaf, topContext);
        docStarts[leafIdx] = leaf.docBase;
      }
    }
  }

  public void testRandom() throws Exception {
    int numberOfRuns = atLeast(1);
    for (int iter = 0; iter < numberOfRuns; iter++) {
      if (VERBOSE) {
        System.out.println("TEST: iter=" + iter);
      }

      final int numDocs = atLeast(100);
      // final int numDocs = _TestUtil.nextInt(random, 5, 20);

      final int numGroups = TestUtil.nextInt(random(), 1, numDocs);

      if (VERBOSE) {
        System.out.println("TEST: numDocs=" + numDocs + " numGroups=" + numGroups);
      }

      final List<BytesRef> groups = new ArrayList<>();
      for (int i = 0; i < numGroups; i++) {
        String randomValue;
        do {
          // B/c of DV based impl we can't see the difference between an empty string and a null
          // value.
          // For that reason we don't generate empty string
          // groups.
          randomValue = TestUtil.randomRealisticUnicodeString(random());
          // randomValue = TestUtil.randomSimpleString(random());
        } while (randomValue.isEmpty());

        groups.add(new BytesRef(randomValue));
      }
      final String[] contentStrings = new String[TestUtil.nextInt(random(), 2, 20)];
      if (VERBOSE) {
        System.out.println("TEST: create fake content");
      }
      for (int contentIDX = 0; contentIDX < contentStrings.length; contentIDX++) {
        final StringBuilder sb = new StringBuilder();
        sb.append("real").append(random().nextInt(3)).append(' ');
        final int fakeCount = random().nextInt(10);
        for (int fakeIDX = 0; fakeIDX < fakeCount; fakeIDX++) {
          sb.append("fake ");
        }
        contentStrings[contentIDX] = sb.toString();
        if (VERBOSE) {
          System.out.println("  content=" + sb.toString());
        }
      }

      Directory dir = newDirectory();
      RandomIndexWriter w =
          new RandomIndexWriter(
              random(),
              dir,
              newIndexWriterConfig(new MockAnalyzer(random()))
                  .setMergePolicy(newMergePolicy(random(), false)));
      Document doc = new Document();
      Document docNoGroup = new Document();
      Field idvGroupField = new SortedDocValuesField("group", new BytesRef());
      doc.add(idvGroupField);
      docNoGroup.add(idvGroupField);

      Field group = newStringField("group", "", Field.Store.NO);
      doc.add(group);
      docNoGroup.add(group);
      Field sort1 = new SortedDocValuesField("sort1", new BytesRef());
      doc.add(sort1);
      docNoGroup.add(sort1);
      Field sort2 = new SortedDocValuesField("sort2", new BytesRef());
      doc.add(sort2);
      docNoGroup.add(sort2);
      Field content = newTextField("content", "", Field.Store.NO);
      doc.add(content);
      docNoGroup.add(content);
      NumericDocValuesField idDV = new NumericDocValuesField("id", 0);
      doc.add(idDV);
      docNoGroup.add(idDV);
      final GroupDoc[] groupDocs = new GroupDoc[numDocs];
      for (int i = 0; i < numDocs; i++) {
        final BytesRef groupValue;
        if (random().nextInt(24) == 17) {
          // So we test the "doc doesn't have the group'd
          // field" case:
          groupValue = null;
        } else {
          groupValue = groups.get(random().nextInt(groups.size()));
        }
        final GroupDoc groupDoc =
            new GroupDoc(
                i,
                groupValue,
                groups.get(random().nextInt(groups.size())),
                groups.get(random().nextInt(groups.size())),
                contentStrings[random().nextInt(contentStrings.length)]);
        if (VERBOSE) {
          System.out.println(
              "  doc content="
                  + groupDoc.content
                  + " id="
                  + i
                  + " group="
                  + (groupDoc.group == null ? "null" : groupDoc.group.utf8ToString())
                  + " sort1="
                  + groupDoc.sort1.utf8ToString()
                  + " sort2="
                  + groupDoc.sort2.utf8ToString());
        }

        groupDocs[i] = groupDoc;
        if (groupDoc.group != null) {
          group.setStringValue(groupDoc.group.utf8ToString());
          idvGroupField.setBytesValue(BytesRef.deepCopyOf(groupDoc.group));
        } else {
          // TODO: not true
          // Must explicitly set empty string, else eg if
          // the segment has all docs missing the field then
          // we get null back instead of empty BytesRef:
          idvGroupField.setBytesValue(new BytesRef());
        }
        sort1.setBytesValue(BytesRef.deepCopyOf(groupDoc.sort1));
        sort2.setBytesValue(BytesRef.deepCopyOf(groupDoc.sort2));
        content.setStringValue(groupDoc.content);
        idDV.setLongValue(groupDoc.id);
        if (groupDoc.group == null) {
          w.addDocument(docNoGroup);
        } else {
          w.addDocument(doc);
        }
      }

      final GroupDoc[] groupDocsByID = new GroupDoc[groupDocs.length];
      System.arraycopy(groupDocs, 0, groupDocsByID, 0, groupDocs.length);

      final DirectoryReader r = w.getReader();
      w.close();

      NumericDocValues values = MultiDocValues.getNumericValues(r, "id");
      int[] docIDToID = new int[r.maxDoc()];
      for (int i = 0; i < r.maxDoc(); i++) {
        assertEquals(i, values.nextDoc());
        docIDToID[i] = (int) values.longValue();
      }
      DirectoryReader rBlocks = null;
      Directory dirBlocks = null;

      final IndexSearcher s = newSearcher(r);
      // This test relies on the fact that longer fields produce lower scores
      s.setSimilarity(new BM25Similarity());

      if (VERBOSE) {
        System.out.println("\nTEST: searcher=" + s);
      }

      final ShardState shards = new ShardState(s);

      Set<Integer> seenIDs = new HashSet<>();
      for (int contentID = 0; contentID < 3; contentID++) {
        final ScoreDoc[] hits =
            s.search(new TermQuery(new Term("content", "real" + contentID)), numDocs).scoreDocs;
        for (ScoreDoc hit : hits) {
          int idValue = docIDToID[hit.doc];

          final GroupDoc gd = groupDocs[idValue];
          seenIDs.add(idValue);
          assertTrue(gd.score == 0.0);
          gd.score = hit.score;
          assertEquals(gd.id, idValue);
        }
      }

      // make sure all groups were seen across the hits
      assertEquals(groupDocs.length, seenIDs.size());

      for (GroupDoc gd : groupDocs) {
        assertTrue(Float.isFinite(gd.score));
        assertTrue(gd.score >= 0.0);
      }

      // Build 2nd index, where docs are added in blocks by
      // group, so we can use single pass collector
      dirBlocks = newDirectory();
      rBlocks = getDocBlockReader(dirBlocks, groupDocs);
      final Query lastDocInBlock = new TermQuery(new Term("groupend", "x"));

      final IndexSearcher sBlocks = newSearcher(rBlocks);
      // This test relies on the fact that longer fields produce lower scores
      sBlocks.setSimilarity(new BM25Similarity());

      final ShardState shardsBlocks = new ShardState(sBlocks);

      // ReaderBlocks only increases maxDoc() vs reader, which
      // means a monotonic shift in scores, so we can
      // reliably remap them w/ Map:
      final Map<String, Map<Float, Float>> scoreMap = new HashMap<>();

      values = MultiDocValues.getNumericValues(rBlocks, "id");
      assertNotNull(values);
      int[] docIDToIDBlocks = new int[rBlocks.maxDoc()];
      for (int i = 0; i < rBlocks.maxDoc(); i++) {
        assertEquals(i, values.nextDoc());
        docIDToIDBlocks[i] = (int) values.longValue();
      }

      // Tricky: must separately set .score2, because the doc
      // block index was created with possible deletions!
      // System.out.println("fixup score2");
      for (int contentID = 0; contentID < 3; contentID++) {
        // System.out.println("  term=real" + contentID);
        final Map<Float, Float> termScoreMap = new HashMap<>();
        scoreMap.put("real" + contentID, termScoreMap);
        // System.out.println("term=real" + contentID + " dfold=" + s.docFreq(new Term("content",
        // "real"+contentID)) +
        // " dfnew=" + sBlocks.docFreq(new Term("content", "real"+contentID)));
        final ScoreDoc[] hits =
            sBlocks.search(new TermQuery(new Term("content", "real" + contentID)), numDocs)
                .scoreDocs;
        for (ScoreDoc hit : hits) {
          final GroupDoc gd = groupDocsByID[docIDToIDBlocks[hit.doc]];
          assertTrue(gd.score2 == 0.0);
          gd.score2 = hit.score;
          assertEquals(gd.id, docIDToIDBlocks[hit.doc]);
          // System.out.println("    score=" + gd.score + " score2=" + hit.score + " id=" +
          // docIDToIDBlocks[hit.doc]);
          termScoreMap.put(gd.score, gd.score2);
        }
      }

      for (int searchIter = 0; searchIter < 100; searchIter++) {

        if (VERBOSE) {
          System.out.println("\nTEST: searchIter=" + searchIter);
        }

        final String searchTerm = "real" + random().nextInt(3);
        final boolean getMaxScores = random().nextBoolean();
        final Sort groupSort = getRandomSort();
        // final Sort groupSort = new Sort(new SortField[] {new SortField("sort1",
        // SortField.STRING), new SortField("id", SortField.INT)});
        final Sort docSort = getRandomSort();

        final int topNGroups = TestUtil.nextInt(random(), 1, 30);
        // final int topNGroups = 10;
        final int docsPerGroup = TestUtil.nextInt(random(), 1, 50);

        final int groupOffset = TestUtil.nextInt(random(), 0, (topNGroups - 1) / 2);
        // final int groupOffset = 0;

        final int docOffset = TestUtil.nextInt(random(), 0, docsPerGroup - 1);
        // final int docOffset = 0;

        final boolean doCache = random().nextBoolean();
        final boolean doAllGroups = random().nextBoolean();
        if (VERBOSE) {
          System.out.println(
              "TEST: groupSort="
                  + groupSort
                  + " docSort="
                  + docSort
                  + " searchTerm="
                  + searchTerm
                  + " dF="
                  + r.docFreq(new Term("content", searchTerm))
                  + " dFBlock="
                  + rBlocks.docFreq(new Term("content", searchTerm))
                  + " topNGroups="
                  + topNGroups
                  + " groupOffset="
                  + groupOffset
                  + " docOffset="
                  + docOffset
                  + " doCache="
                  + doCache
                  + " docsPerGroup="
                  + docsPerGroup
                  + " doAllGroups="
                  + doAllGroups
                  + " getMaxScores="
                  + getMaxScores);
        }

        String groupField = "group";
        if (VERBOSE) {
          System.out.println("  groupField=" + groupField);
        }
        final FirstPassGroupingCollector<?> c1 =
            createRandomFirstPassCollector(groupField, groupSort, groupOffset + topNGroups);
        final CachingCollector cCache;
        final Collector c;

        final AllGroupsCollector<?> allGroupsCollector;
        if (doAllGroups) {
          allGroupsCollector = createAllGroupsCollector(c1, groupField);
        } else {
          allGroupsCollector = null;
        }

        final boolean useWrappingCollector = random().nextBoolean();

        if (doCache) {
          final double maxCacheMB = random().nextDouble();
          if (VERBOSE) {
            System.out.println("TEST: maxCacheMB=" + maxCacheMB);
          }

          if (useWrappingCollector) {
            if (doAllGroups) {
              cCache = CachingCollector.create(c1, true, maxCacheMB);
              c = MultiCollector.wrap(cCache, allGroupsCollector);
            } else {
              c = cCache = CachingCollector.create(c1, true, maxCacheMB);
            }
          } else {
            // Collect only into cache, then replay multiple times:
            c = cCache = CachingCollector.create(true, maxCacheMB);
          }
        } else {
          cCache = null;
          if (doAllGroups) {
            c = MultiCollector.wrap(c1, allGroupsCollector);
          } else {
            c = c1;
          }
        }

        // Search top reader:
        final Query query = new TermQuery(new Term("content", searchTerm));

        s.search(query, c);

        if (doCache && !useWrappingCollector) {
          if (cCache.isCached()) {
            // Replay for first-pass grouping
            cCache.replay(c1);
            if (doAllGroups) {
              // Replay for all groups:
              cCache.replay(allGroupsCollector);
            }
          } else {
            // Replay by re-running search:
            s.search(query, c1);
            if (doAllGroups) {
              s.search(query, allGroupsCollector);
            }
          }
        }

        // Get 1st pass top groups
        final Collection<SearchGroup<BytesRef>> topGroups = getSearchGroups(c1, groupOffset);
        final TopGroups<BytesRef> groupsResult;
        if (VERBOSE) {
          System.out.println("TEST: first pass topGroups");
          if (topGroups == null) {
            System.out.println("  null");
          } else {
            for (SearchGroup<BytesRef> searchGroup : topGroups) {
              System.out.println(
                  "  "
                      + (searchGroup.groupValue == null ? "null" : searchGroup.groupValue)
                      + ": "
                      + Arrays.deepToString(searchGroup.sortValues));
            }
          }
        }

        // Get 1st pass top groups using shards

        final TopGroups<BytesRef> topGroupsShards =
            searchShards(
                s,
                shards.subSearchers,
                query,
                groupSort,
                docSort,
                groupOffset,
                topNGroups,
                docOffset,
                docsPerGroup,
                getMaxScores,
                true,
                true);
        final TopGroupsCollector<?> c2;
        if (topGroups != null) {

          if (VERBOSE) {
            System.out.println("TEST: topGroups");
            for (SearchGroup<BytesRef> searchGroup : topGroups) {
              System.out.println(
                  "  "
                      + (searchGroup.groupValue == null
                          ? "null"
                          : searchGroup.groupValue.utf8ToString())
                      + ": "
                      + Arrays.deepToString(searchGroup.sortValues));
            }
          }

          c2 =
              createSecondPassCollector(
                  c1, groupSort, docSort, groupOffset, docOffset + docsPerGroup, getMaxScores);
          if (doCache) {
            if (cCache.isCached()) {
              if (VERBOSE) {
                System.out.println("TEST: cache is intact");
              }
              cCache.replay(c2);
            } else {
              if (VERBOSE) {
                System.out.println("TEST: cache was too large");
              }
              s.search(query, c2);
            }
          } else {
            s.search(query, c2);
          }

          if (doAllGroups) {
            TopGroups<BytesRef> tempTopGroups = getTopGroups(c2, docOffset);
            groupsResult = new TopGroups<>(tempTopGroups, allGroupsCollector.getGroupCount());
          } else {
            groupsResult = getTopGroups(c2, docOffset);
          }
        } else {
          c2 = null;
          groupsResult = null;
          if (VERBOSE) {
            System.out.println("TEST:   no results");
          }
        }

        final TopGroups<BytesRef> expectedGroups =
            slowGrouping(
                groupDocs,
                searchTerm,
                getMaxScores,
                doAllGroups,
                groupSort,
                docSort,
                topNGroups,
                docsPerGroup,
                groupOffset,
                docOffset);

        if (VERBOSE) {
          if (expectedGroups == null) {
            System.out.println("TEST: no expected groups");
          } else {
            System.out.println(
                "TEST: expected groups totalGroupedHitCount="
                    + expectedGroups.totalGroupedHitCount);
            for (GroupDocs<BytesRef> gd : expectedGroups.groups) {
              System.out.println(
                  "  group="
                      + (gd.groupValue() == null ? "null" : gd.groupValue())
                      + " totalHits="
                      + gd.totalHits().value()
                      + " scoreDocs.len="
                      + gd.scoreDocs().length);
              for (ScoreDoc sd : gd.scoreDocs()) {
                System.out.println("    id=" + sd.doc + " score=" + sd.score);
              }
            }
          }

          if (groupsResult == null) {
            System.out.println("TEST: no matched groups");
          } else {
            System.out.println(
                "TEST: matched groups totalGroupedHitCount=" + groupsResult.totalGroupedHitCount);
            for (GroupDocs<BytesRef> gd : groupsResult.groups) {
              System.out.println(
                  "  group="
                      + (gd.groupValue() == null ? "null" : gd.groupValue())
                      + " totalHits="
                      + gd.totalHits().value());
              for (ScoreDoc sd : gd.scoreDocs()) {
                System.out.println("    id=" + docIDToID[sd.doc] + " score=" + sd.score);
              }
            }

            if (searchIter == 14) {
              for (int docIDX = 0; docIDX < s.getIndexReader().maxDoc(); docIDX++) {
                System.out.println(
                    "ID=" + docIDToID[docIDX] + " explain=" + s.explain(query, docIDX));
              }
            }
          }

          if (topGroupsShards == null) {
            System.out.println("TEST: no matched-merged groups");
          } else {
            System.out.println(
                "TEST: matched-merged groups totalGroupedHitCount="
                    + topGroupsShards.totalGroupedHitCount);
            for (GroupDocs<BytesRef> gd : topGroupsShards.groups) {
              System.out.println(
                  "  group="
                      + (gd.groupValue() == null ? "null" : gd.groupValue())
                      + " totalHits="
                      + gd.totalHits().value());
              for (ScoreDoc sd : gd.scoreDocs()) {
                System.out.println("    id=" + docIDToID[sd.doc] + " score=" + sd.score);
              }
            }
          }
        }

        assertEquals(docIDToID, expectedGroups, groupsResult, true, true, true);

        // Confirm merged shards match:
        assertEquals(docIDToID, expectedGroups, topGroupsShards, true, false, true);
        if (topGroupsShards != null) {
          verifyShards(shards.docStarts, topGroupsShards);
        }

        final BlockGroupingCollector c3 =
            new BlockGroupingCollector(
                groupSort,
                groupOffset + topNGroups,
                groupSort.needsScores() || docSort.needsScores(),
                sBlocks.createWeight(
                    sBlocks.rewrite(lastDocInBlock), ScoreMode.COMPLETE_NO_SCORES, 1));
        final AllGroupsCollector<BytesRef> allGroupsCollector2;
        final Collector c4;
        if (doAllGroups) {
          // NOTE: must be "group" and not "group_dv"
          // (groupField) because we didn't index doc
          // values in the block index:
          allGroupsCollector2 = new AllGroupsCollector<>(new TermGroupSelector("group"));
          c4 = MultiCollector.wrap(c3, allGroupsCollector2);
        } else {
          allGroupsCollector2 = null;
          c4 = c3;
        }
        // Get block grouping result:
        sBlocks.search(query, c4);
        @SuppressWarnings({"unchecked", "rawtypes"})
        final TopGroups<BytesRef> tempTopGroupsBlocks =
            (TopGroups<BytesRef>)
                c3.getTopGroups(docSort, groupOffset, docOffset, docOffset + docsPerGroup);
        final TopGroups<BytesRef> groupsResultBlocks;
        if (doAllGroups && tempTopGroupsBlocks != null) {
          assertEquals(
              (int) tempTopGroupsBlocks.totalGroupCount, allGroupsCollector2.getGroupCount());
          groupsResultBlocks =
              new TopGroups<>(tempTopGroupsBlocks, allGroupsCollector2.getGroupCount());
        } else {
          groupsResultBlocks = tempTopGroupsBlocks;
        }

        if (VERBOSE) {
          if (groupsResultBlocks == null) {
            System.out.println("TEST: no block groups");
          } else {
            System.out.println(
                "TEST: block groups totalGroupedHitCount="
                    + groupsResultBlocks.totalGroupedHitCount);
            boolean first = true;
            for (GroupDocs<BytesRef> gd : groupsResultBlocks.groups) {
              System.out.println(
                  "  group="
                      + (gd.groupValue() == null ? "null" : gd.groupValue().utf8ToString())
                      + " totalHits="
                      + gd.totalHits().value());
              for (ScoreDoc sd : gd.scoreDocs()) {
                System.out.println("    id=" + docIDToIDBlocks[sd.doc] + " score=" + sd.score);
                if (first) {
                  System.out.println("explain: " + sBlocks.explain(query, sd.doc));
                  first = false;
                }
              }
            }
          }
        }

        // Get shard'd block grouping result:
        final TopGroups<BytesRef> topGroupsBlockShards =
            searchShards(
                sBlocks,
                shardsBlocks.subSearchers,
                query,
                groupSort,
                docSort,
                groupOffset,
                topNGroups,
                docOffset,
                docsPerGroup,
                getMaxScores,
                false,
                false);

        if (expectedGroups != null) {
          // Fixup scores for reader2
          for (GroupDocs<?> groupDocsHits : expectedGroups.groups) {
            for (ScoreDoc hit : groupDocsHits.scoreDocs()) {
              final GroupDoc gd = groupDocsByID[hit.doc];
              assertEquals(gd.id, hit.doc);
              // System.out.println("fixup score " + hit.score + " to " + gd.score2 + " vs " +
              // gd.score);
              hit.score = gd.score2;
            }
          }

          final SortField[] sortFields = groupSort.getSort();
          final Map<Float, Float> termScoreMap = scoreMap.get(searchTerm);
          for (int groupSortIDX = 0; groupSortIDX < sortFields.length; groupSortIDX++) {
            if (sortFields[groupSortIDX].getType() == SortField.Type.SCORE) {
              for (GroupDocs<?> groupDocsHits : expectedGroups.groups) {
                if (groupDocsHits.groupSortValues() != null) {
                  // System.out.println("remap " + groupDocsHits.groupSortValues[groupSortIDX] + "
                  // to " + termScoreMap.get(groupDocsHits.groupSortValues[groupSortIDX]));
                  groupDocsHits.groupSortValues()[groupSortIDX] =
                      termScoreMap.get(groupDocsHits.groupSortValues()[groupSortIDX]);
                  assertNotNull(groupDocsHits.groupSortValues()[groupSortIDX]);
                }
              }
            }
          }

          final SortField[] docSortFields = docSort.getSort();
          for (int docSortIDX = 0; docSortIDX < docSortFields.length; docSortIDX++) {
            if (docSortFields[docSortIDX].getType() == SortField.Type.SCORE) {
              for (GroupDocs<?> groupDocsHits : expectedGroups.groups) {
                for (ScoreDoc _hit : groupDocsHits.scoreDocs()) {
                  FieldDoc hit = (FieldDoc) _hit;
                  if (hit.fields != null) {
                    hit.fields[docSortIDX] = termScoreMap.get(hit.fields[docSortIDX]);
                    assertNotNull(hit.fields[docSortIDX]);
                  }
                }
              }
            }
          }
        }

        assertEquals(docIDToIDBlocks, expectedGroups, groupsResultBlocks, false, true, false);
        assertEquals(docIDToIDBlocks, expectedGroups, topGroupsBlockShards, false, false, false);
      }

      r.close();
      dir.close();

      rBlocks.close();
      dirBlocks.close();
    }
  }

  /**
   * Checks that every hit in {@code topGroups} reports, in {@code shardIndex}, the shard whose
   * docID range contains the hit's top-level docID.
   *
   * @param docStarts starting top-level docID of each shard, as consumed by {@link
   *     ReaderUtil#subIndex(int, int[])}
   * @param topGroups merged grouping results whose hits carry a {@code shardIndex}
   */
  private void verifyShards(int[] docStarts, TopGroups<BytesRef> topGroups) {
    for (GroupDocs<?> groupDocs : topGroups.groups) {
      for (ScoreDoc hit : groupDocs.scoreDocs()) {
        final int owningShard = ReaderUtil.subIndex(hit.doc, docStarts);
        assertEquals("doc=" + hit.doc + " wrong shard", owningShard, hit.shardIndex);
      }
    }
  }

  /**
   * Runs two-pass grouping against each shard searcher independently and merges the per-shard
   * results, mimicking a distributed grouped search.
   *
   * <p>First pass: every shard collects its own top {@code groupOffset + topNGroups} groups, and
   * those per-shard groups are merged with {@link SearchGroup#merge}. Second pass: every shard
   * collects docs for the merged top groups, and the per-shard {@link TopGroups} are merged with
   * {@link TopGroups#merge}.
   *
   * @param topSearcher searcher used to rewrite {@code query} and create the single {@link Weight}
   *     shared by all shards
   * @param subSearchers one searcher per shard; each is searched with the shared weight
   * @param query query to run
   * @param groupSort sort used to rank groups
   * @param docSort sort used to rank docs within each group
   * @param groupOffset offset into the merged top groups
   * @param topNGroups number of groups to keep after {@code groupOffset}
   * @param docOffset offset into each group's docs
   * @param topNDocs number of docs to keep per group after {@code docOffset}
   * @param getMaxScores forces {@link ScoreMode#COMPLETE} and is passed to the second-pass
   *     collectors
   * @param canUseIDV only printed in verbose output; not otherwise used by this method
   * @param preFlex unused by this method
   * @return the merged top groups, or null when {@link SearchGroup#merge} returned null
   *     (NOTE(review): presumably when no shard produced any first-pass group — confirm)
   */
  private TopGroups<BytesRef> searchShards(
      IndexSearcher topSearcher,
      ShardSearcher[] subSearchers,
      Query query,
      Sort groupSort,
      Sort docSort,
      int groupOffset,
      int topNGroups,
      int docOffset,
      int topNDocs,
      boolean getMaxScores,
      boolean canUseIDV,
      boolean preFlex)
      throws Exception {

    // TODO: swap in caching and the all-groups collector here
    // too...
    if (VERBOSE) {
      System.out.println(
          "TEST: "
              + subSearchers.length
              + " shards: "
              + Arrays.toString(subSearchers)
              + " canUseIDV="
              + canUseIDV);
    }
    // Run 1st pass collector to get top groups per shard
    // A single Weight is built once on the top-level searcher and reused for every shard,
    // so all shards score with the same (top-level) statistics.
    final Weight w =
        topSearcher.createWeight(
            topSearcher.rewrite(query),
            groupSort.needsScores() || docSort.needsScores() || getMaxScores
                ? ScoreMode.COMPLETE
                : ScoreMode.COMPLETE_NO_SCORES,
            1);
    final List<Collection<SearchGroup<BytesRef>>> shardGroups = new ArrayList<>();
    // Kept per shard (indexed by shardIDX) so the 2nd pass can reuse each shard's collector.
    List<FirstPassGroupingCollector<?>> firstPassGroupingCollectors = new ArrayList<>();
    FirstPassGroupingCollector<?> firstPassCollector = null;

    String groupField = "group";

    for (int shardIDX = 0; shardIDX < subSearchers.length; shardIDX++) {

      // First shard determines whether we use IDV or not;
      // all other shards match that:
      if (firstPassCollector == null) {
        firstPassCollector =
            createRandomFirstPassCollector(groupField, groupSort, groupOffset + topNGroups);
      } else {
        firstPassCollector =
            createFirstPassCollector(
                groupField, groupSort, groupOffset + topNGroups, firstPassCollector);
      }
      if (VERBOSE) {
        System.out.println("  shard=" + shardIDX + " groupField=" + groupField);
        System.out.println("    1st pass collector=" + firstPassCollector);
      }
      firstPassGroupingCollectors.add(firstPassCollector);
      subSearchers[shardIDX].search(w, firstPassCollector);
      // Offset 0 per shard: groupOffset is applied only once, when merging across shards.
      final Collection<SearchGroup<BytesRef>> topGroups = getSearchGroups(firstPassCollector, 0);
      if (topGroups != null) {
        if (VERBOSE) {
          System.out.println(
              "  shard "
                  + shardIDX
                  + " s="
                  + subSearchers[shardIDX]
                  + " totalGroupedHitCount=?"
                  + " "
                  + topGroups.size()
                  + " groups:");
          for (SearchGroup<BytesRef> group : topGroups) {
            System.out.println(
                "    "
                    + groupToString(group.groupValue)
                    + " groupSort="
                    + Arrays.toString(group.sortValues));
          }
        }
        // Shards that produced no groups contribute nothing to the merge.
        shardGroups.add(topGroups);
      }
    }

    final Collection<SearchGroup<BytesRef>> mergedTopGroups =
        SearchGroup.merge(shardGroups, groupOffset, topNGroups, groupSort);
    if (VERBOSE) {
      System.out.println(" top groups merged:");
      if (mergedTopGroups == null) {
        System.out.println("    null");
      } else {
        System.out.println("    " + mergedTopGroups.size() + " top groups:");
        for (SearchGroup<BytesRef> group : mergedTopGroups) {
          System.out.println(
              "    ["
                  + groupToString(group.groupValue)
                  + "] groupSort="
                  + Arrays.toString(group.sortValues));
        }
      }
    }

    if (mergedTopGroups != null) {
      // Now 2nd pass:
      // Every shard collects docs for the same merged top groups.
      @SuppressWarnings({"unchecked", "rawtypes"})
      final TopGroups<BytesRef>[] shardTopGroups = new TopGroups[subSearchers.length];
      for (int shardIDX = 0; shardIDX < subSearchers.length; shardIDX++) {
        final TopGroupsCollector<?> secondPassCollector =
            createSecondPassCollector(
                firstPassGroupingCollectors.get(shardIDX),
                groupField,
                mergedTopGroups,
                groupSort,
                docSort,
                docOffset + topNDocs,
                getMaxScores);
        subSearchers[shardIDX].search(w, secondPassCollector);
        // Per-shard doc offset is 0; docOffset is applied by TopGroups.merge below.
        shardTopGroups[shardIDX] = getTopGroups(secondPassCollector, 0);
        if (VERBOSE) {
          System.out.println(
              " " + shardTopGroups[shardIDX].groups.length + " shard[" + shardIDX + "] groups:");
          for (GroupDocs<BytesRef> group : shardTopGroups[shardIDX].groups) {
            System.out.println(
                "    ["
                    + groupToString(group.groupValue())
                    + "] groupSort="
                    + Arrays.toString(group.groupSortValues())
                    + " numDocs="
                    + group.scoreDocs().length);
          }
        }
      }

      TopGroups<BytesRef> mergedGroups =
          TopGroups.merge(
              shardTopGroups,
              groupSort,
              docSort,
              docOffset,
              topNDocs,
              TopGroups.ScoreMergeMode.None);
      if (VERBOSE) {
        System.out.println(" " + mergedGroups.groups.length + " merged groups:");
        for (GroupDocs<BytesRef> group : mergedGroups.groups) {
          System.out.println(
              "    ["
                  + groupToString(group.groupValue())
                  + "] groupSort="
                  + Arrays.toString(group.groupSortValues())
                  + " numDocs="
                  + group.scoreDocs().length);
        }
      }
      return mergedGroups;
    } else {
      return null;
    }
  }

  /**
   * Asserts that {@code actual} grouping results match {@code expected}.
   *
   * <p>Hits in {@code expected} carry the stable per-document "id" in {@link ScoreDoc#doc}. Hits in
   * {@code actual} carry docIDs of the searched reader and are mapped through {@code docIDToID}
   * before comparison.
   *
   * @param docIDToID maps a docID of the searched reader to the document's stable id
   * @param expected expected results; if null, {@code actual} must be null as well
   * @param actual results produced by the collectors under test
   * @param verifyGroupValues whether to compare each group's value
   * @param verifyTotalGroupCount whether to compare {@code totalGroupCount} (only checked when
   *     {@code expected} has one)
   * @param idvBasedImplsUsed whether {@code actual} came from doc-values based grouping, where the
   *     "no group" bucket may be reported as an empty {@link BytesRef} (or null) while
   *     {@code expected} uses null
   */
  private void assertEquals(
      int[] docIDToID,
      TopGroups<BytesRef> expected,
      TopGroups<BytesRef> actual,
      boolean verifyGroupValues,
      boolean verifyTotalGroupCount,
      boolean idvBasedImplsUsed) {
    if (expected == null) {
      assertNull(actual);
      return;
    }
    assertNotNull(actual);

    assertEquals(
        "expected.groups.length != actual.groups.length",
        expected.groups.length,
        actual.groups.length);
    assertEquals(
        "expected.totalHitCount != actual.totalHitCount",
        expected.totalHitCount,
        actual.totalHitCount);
    assertEquals(
        "expected.totalGroupedHitCount != actual.totalGroupedHitCount",
        expected.totalGroupedHitCount,
        actual.totalGroupedHitCount);
    if (expected.totalGroupCount != null && verifyTotalGroupCount) {
      assertEquals(
          "expected.totalGroupCount != actual.totalGroupCount",
          expected.totalGroupCount,
          actual.totalGroupCount);
    }

    for (int groupIDX = 0; groupIDX < expected.groups.length; groupIDX++) {
      if (VERBOSE) {
        System.out.println("  check groupIDX=" + groupIDX);
      }
      final GroupDocs<BytesRef> expectedGroup = expected.groups[groupIDX];
      final GroupDocs<BytesRef> actualGroup = actual.groups[groupIDX];
      final String groupMsg = "groupIDX=" + groupIDX;
      if (verifyGroupValues) {
        final BytesRef actualValue = actualGroup.groupValue();
        if (idvBasedImplsUsed && (actualValue == null || actualValue.length == 0)) {
          // Doc-values based grouping reports the "doc has no group" bucket as an empty (or null)
          // BytesRef; the expected results must model that bucket as null. Checking for null here
          // avoids an NPE that would hide the real assertion.
          assertNull(groupMsg + " expected null group value", expectedGroup.groupValue());
        } else {
          assertEquals(groupMsg + " groupValue", expectedGroup.groupValue(), actualValue);
        }
      }
      assertArrayEquals(
          groupMsg + " groupSortValues",
          expectedGroup.groupSortValues(),
          actualGroup.groupSortValues());

      // TODO
      // assertEquals(expectedGroup.maxScore, actualGroup.maxScore);
      assertEquals(
          groupMsg + " totalHits",
          expectedGroup.totalHits().value(),
          actualGroup.totalHits().value());

      final ScoreDoc[] expectedFDs = expectedGroup.scoreDocs();
      final ScoreDoc[] actualFDs = actualGroup.scoreDocs();

      assertEquals(groupMsg + " scoreDocs.length", expectedFDs.length, actualFDs.length);
      for (int docIDX = 0; docIDX < expectedFDs.length; docIDX++) {
        final FieldDoc expectedFD = (FieldDoc) expectedFDs[docIDX];
        final FieldDoc actualFD = (FieldDoc) actualFDs[docIDX];
        final String docMsg = groupMsg + " docIDX=" + docIDX;
        // Expected hits already hold the stable id; map the actual reader docID to its id.
        assertEquals(docMsg + " id", expectedFD.doc, docIDToID[actualFD.doc]);
        assertArrayEquals(docMsg + " fields", expectedFD.fields, actualFD.fields);
      }
    }
  }

  /**
   * An {@link IndexSearcher} built on {@code parent} whose {@link #search(Weight, Collector)}
   * visits exactly one leaf. The test uses one instance per leaf to play the role of a shard.
   */
  private static class ShardSearcher extends IndexSearcher {
    /** The single leaf that {@link #search(Weight, Collector)} runs against. */
    private final LeafReaderContext shardLeaf;

    public ShardSearcher(LeafReaderContext ctx, IndexReaderContext parent) {
      super(parent);
      shardLeaf = ctx;
    }

    /**
     * Scores {@code weight} over {@link #shardLeaf}, from doc 0 up to {@link
     * DocIdSetIterator#NO_MORE_DOCS}, feeding matches into {@code collector}.
     */
    public void search(Weight weight, Collector collector) throws IOException {
      final int firstDoc = 0;
      final int endDoc = DocIdSetIterator.NO_MORE_DOCS;
      searchLeaf(shardLeaf, firstDoc, endDoc, weight, collector);
    }

    @Override
    public String toString() {
      return "ShardSearcher(" + shardLeaf.reader() + ")";
    }
  }
}
